{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3615813b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bd4d4395",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(device(type='cpu'), device(type='mps'))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.device('cpu'), torch.device('mps')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "f18134b9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Check whether this Mac supports MPS acceleration\n",
    "torch.backends.mps.is_available()\n",
    "# On a machine with CUDA, the equivalent check is torch.cuda.is_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c0018d88",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<functools._lru_cache_wrapper at 0x10cbc04c0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.backends.mps.is_macos13_or_newer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "447ad376",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='mps')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_gpu():\n",
    "    \"\"\"Return the MPS (Apple GPU) device when it is usable, otherwise the CPU device.\n",
    "\n",
    "    Uses `torch.backends.mps.is_available()` rather than the deprecated\n",
    "    `torch.has_mps` flag: the flag only reports whether PyTorch was built with\n",
    "    MPS support, not whether this machine can actually run on it.\n",
    "\n",
    "    Returns:\n",
    "        torch.device: `device('mps')` if MPS is available, else `device('cpu')`.\n",
    "    \"\"\"\n",
    "    if torch.backends.mps.is_available():\n",
    "        return torch.device('mps')\n",
    "    return torch.device('cpu')\n",
    "\n",
    "get_gpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f4e1dfa8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cpu')"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.tensor([1, 2, 3])\n",
    "x.device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "11d728cc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        ...,\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.]], device='mps:0')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = torch.ones((10000, 256), device=get_gpu())\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "937b9c7c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.4404, 0.4217, 0.4726, 0.6862, 0.6965, 0.5534, 0.9766, 0.5027, 0.3039,\n",
       "         0.8260, 0.9021, 0.1612, 0.8956, 0.7418, 0.1465, 0.9608, 0.4971, 0.4971,\n",
       "         0.7647, 0.3747, 0.3643, 0.6594, 0.3533, 0.6164, 0.2622, 0.0846, 0.4041,\n",
       "         0.3695, 0.1467, 0.4649, 0.3050, 0.1510, 0.1274, 0.4837, 0.7521, 0.2807,\n",
       "         0.3866, 0.3326, 0.7983, 0.9178, 0.0586, 0.7754, 0.1751, 0.4494, 0.5030,\n",
       "         0.0317, 0.0894, 0.7983, 0.4745, 0.6201, 0.5932, 0.2854, 0.8374, 0.9212,\n",
       "         0.6672, 0.9920, 0.8630, 0.0513, 0.5349, 0.4433, 0.4852, 0.8259, 0.3185,\n",
       "         0.0475, 0.4764, 0.6021, 0.5176, 0.4714, 0.9581, 0.4980, 0.0900, 0.9876,\n",
       "         0.4223, 0.0098, 0.9998, 0.8665, 0.3148, 0.1312, 0.7308, 0.7441, 0.1197,\n",
       "         0.4166, 0.5202, 0.3649, 0.0751, 0.8900, 0.1096, 0.8052, 0.9739, 0.2592,\n",
       "         0.0543, 0.9328, 0.5642, 0.3955, 0.0053, 0.0720, 0.9721, 0.2248, 0.1500,\n",
       "         0.2527, 0.0968, 0.0396, 0.0774, 0.5664, 0.4062, 0.0878, 0.1830, 0.7540,\n",
       "         0.3480, 0.4695, 0.1819, 0.1459, 0.6234, 0.3386, 0.4644, 0.3767, 0.9186,\n",
       "         0.8682, 0.0293, 0.8719, 0.1429, 0.0170, 0.2214, 0.5416, 0.7260, 0.1026,\n",
       "         0.7185, 0.3056, 0.2850, 0.5109, 0.0793, 0.2054, 0.7571, 0.2763, 0.7288,\n",
       "         0.5307, 0.9880, 0.7968, 0.7359, 0.5492, 0.3051, 0.4220, 0.1846, 0.6760,\n",
       "         0.8034, 0.7733, 0.1446, 0.6821, 0.6073, 0.6160, 0.2460, 0.6032, 0.9176,\n",
       "         0.2707, 0.5521, 0.7964, 0.0259, 0.8587, 0.3344, 0.9941, 0.7083, 0.9051,\n",
       "         0.0373, 0.8685, 0.7513, 0.9617, 0.9932, 0.1029, 0.7672, 0.0276, 0.0108,\n",
       "         0.9988, 0.2598, 0.0094, 0.9439, 0.8270, 0.0276, 0.8958, 0.6094, 0.4946,\n",
       "         0.2126, 0.3605, 0.1517, 0.3275, 0.9006, 0.7698, 0.7879, 0.1666, 0.2841,\n",
       "         0.2751, 0.2193, 0.6493, 0.9821, 0.5025, 0.1613, 0.3354, 0.5233, 0.5400,\n",
       "         0.8065, 0.9755, 0.1039, 0.8725, 0.3337, 0.1117, 0.9076, 0.4467, 0.3097,\n",
       "         0.1118, 0.6011, 0.2847, 0.4288, 0.4688, 0.2298, 0.3955, 0.5761, 0.3298,\n",
       "         0.7178, 0.8118, 0.4767, 0.7934, 0.4939, 0.1358, 0.6722, 0.8837, 0.5069,\n",
       "         0.6337, 0.4850, 0.0890, 0.8847, 0.2077, 0.6595, 0.6687, 0.4300, 0.7572,\n",
       "         0.3836, 0.3818, 0.1707, 0.3241, 0.5901, 0.0999, 0.9830, 0.8236, 0.6499,\n",
       "         0.8410, 0.7003, 0.0129, 0.2821, 0.0025, 0.8661, 0.5728, 0.5863, 0.4571,\n",
       "         0.0834, 0.9779, 0.4181, 0.9613]], device='mps:0')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y = torch.rand(1, 256, device=get_gpu())\n",
    "Y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5173aa71",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        ...,\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.],\n",
       "        [1., 1., 1.,  ..., 1., 1., 1.]], device='mps:0')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Move X to the device picked by get_gpu() instead of hard-coding 'mps',\n",
    "# so this cell also runs on machines where MPS is not available.\n",
    "Z = X.to(get_gpu())\n",
    "Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "b6c74a5e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[1.4404, 1.4217, 1.4726,  ..., 1.9779, 1.4181, 1.9613],\n",
       "        [1.4404, 1.4217, 1.4726,  ..., 1.9779, 1.4181, 1.9613],\n",
       "        [1.4404, 1.4217, 1.4726,  ..., 1.9779, 1.4181, 1.9613],\n",
       "        ...,\n",
       "        [1.4404, 1.4217, 1.4726,  ..., 1.9779, 1.4181, 1.9613],\n",
       "        [1.4404, 1.4217, 1.4726,  ..., 1.9779, 1.4181, 1.9613],\n",
       "        [1.4404, 1.4217, 1.4726,  ..., 1.9779, 1.4181, 1.9613]],\n",
       "       device='mps:0')"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Y + Z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c22368cb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Z.is_mps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d7095c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# MLP: 256 -> 128 -> 64 -> 16 -> 1, with Sigmoid / ReLU / Tanh between the linear layers.\n",
    "# Module.to() returns the module itself, so the device move is chained onto the constructor.\n",
    "net = nn.Sequential(\n",
    "    nn.Linear(256, 128), nn.Sigmoid(),\n",
    "    nn.Linear(128, 64), nn.ReLU(),\n",
    "    nn.Linear(64, 16), nn.Tanh(),\n",
    "    nn.Linear(16, 1),\n",
    ").to(device=get_gpu())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "32237e39",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([10000, 256])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d3c4e4de",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.0235],\n",
       "        [0.0235],\n",
       "        [0.0235],\n",
       "        ...,\n",
       "        [0.0235],\n",
       "        [0.0235],\n",
       "        [0.0235]], device='mps:0', grad_fn=<LinearBackward0>)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ffc7bc22",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='mps', index=0)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net[0].weight.data.device"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python(d2l)",
   "language": "python",
   "name": "d2l"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
